import os

import numpy as np
import scipy
import torch
from sklearn.neighbors import LocalOutlierFactor

from generic.data_util import divide_dataset_according2date, ICEHOCKEY_GAME_FEATURES, ICEHOCKEY_ACTIONS, Transition


def train_lof(agent, debug_mode, sanity_check_msg):
    """Fit a Local Outlier Factor model on the training split of the game data.

    The files found in ``agent.train_data_path`` are sorted, split by
    ``divide_dataset_according2date`` and only the training part is used.
    Each game is loaded through the agent, converted to transitions, turned
    into a 2-D feature matrix by ``get_lof_input``, and all games are
    concatenated along axis 0 before fitting.

    Args:
        agent: Object providing the data configuration (``train_data_path``,
            ``train_rate``, ``sports``, ``apply_data_date_div``), the loaders
            (``load_sports_data``, ``load_player_id``), the transition
            builders (``build_rnn_transitions`` / ``build_transitions``) and
            the LOF settings (``lof_apply_history``, ``lof_metric``,
            ``lof_neighbors``).
        debug_mode: When truthy, only the first two training files are used.
        sanity_check_msg: Forwarded unchanged to ``agent.load_sports_data``
            and ``get_lof_input``.

    Returns:
        The fitted ``LocalOutlierFactor`` estimator.

    Raises:
        ValueError: If the training split contains no files.
    """
    all_files = sorted(os.listdir(agent.train_data_path))
    training_files, _, _ = divide_dataset_according2date(all_data_files=all_files,
                                                         train_rate=agent.train_rate,
                                                         sports=agent.sports,
                                                         if_split=agent.apply_data_date_div
                                                         )
    if debug_mode:
        training_files = training_files[:2]
    if len(training_files) == 0:
        # Fail early with a meaningful message instead of letting
        # np.concatenate complain about an empty list further down.
        raise ValueError("No training files found in {0}".format(agent.train_data_path))

    # Recurrent agents use a dedicated transition builder; pick it once.
    build_transitions = agent.build_rnn_transitions if agent.apply_rnn else agent.build_transitions

    game_data_all = []
    for file_name in training_files:
        s_a_sequence, r_sequence = agent.load_sports_data(game_label=file_name,
                                                          sanity_check_msg=sanity_check_msg)
        pid_sequence = agent.load_player_id(game_label=file_name)
        transition_all = build_transitions(s_a_data=s_a_sequence,
                                           r_data=r_sequence,
                                           pid_sequence=pid_sequence)
        # Transpose the list of transitions into one Transition of per-field tuples.
        transition_data = Transition(*zip(*transition_all))
        game_data = get_lof_input(agent=agent,
                                  state_action=transition_data.state_action,
                                  trace=transition_data.trace,
                                  apply_history=agent.lof_apply_history,
                                  sanity_check_msg=sanity_check_msg)
        game_data_all.append(game_data)
    game_data_all = np.concatenate(game_data_all, axis=0)

    if agent.lof_metric == 'chebyshev':
        metric = scipy.spatial.distance.chebyshev
    else:
        # Any other setting is handed to scikit-learn as-is (metric name or
        # callable); previously `metric` was left unbound in this case.
        metric = agent.lof_metric
    clf = LocalOutlierFactor(n_neighbors=agent.lof_neighbors, metric=metric)
    # Only the fitted estimator is returned, so the training labels that
    # fit_predict would compute are not needed.
    # NOTE(review): novelty is left at its default, so the estimator is fitted
    # for outlier detection on the training data — confirm callers do not
    # expect to call predict() on new samples.
    clf.fit(game_data_all)
    return clf


def get_lof_input(agent, state_action, trace, apply_history, sanity_check_msg=None):
    """Build the 2-D NumPy feature matrix fed to the Local Outlier Factor model.

    Args:
        agent: Not read by this function; kept so existing callers that pass
            it keep working.
        state_action: Sequence of tensors, one per sample. Each tensor is
            indexed as ``[row, :]``, so it must have at least two dimensions
            (presumably (trace_length, feature_dim) — TODO confirm).
        trace: Sequence aligned with ``state_action``; ``trace[i] - 1`` is the
            row taken from sample ``i`` when ``apply_history`` is falsy
            (presumably the 1-based length of the valid history — TODO confirm).
        apply_history: When truthy, every sample is flattened in full, which
            requires all samples to share one shape. Otherwise only the row
            at ``trace[i] - 1`` is kept for each sample.
        sanity_check_msg: Either ``None`` or a string containing both
            ``'location'`` and ``'ha'``; it is only validated here.

    Returns:
        ``numpy.ndarray`` with one row per sample.

    Raises:
        ValueError: If ``sanity_check_msg`` is neither ``None`` nor a
            recognised message.
    """
    if sanity_check_msg is not None and not (
            'location' in sanity_check_msg and 'ha' in sanity_check_msg):
        # The format string previously had no placeholder, so the offending
        # value never showed up in the error message.
        raise ValueError("Unknown sanity_check_msg: {0}".format(sanity_check_msg))

    batch_size = len(state_action)
    if apply_history:
        # Flatten the whole history of each sample into a single row.
        tgt_data = torch.reshape(torch.stack(state_action), shape=(batch_size, -1))
    else:
        # Keep only the row at index trace[i] - 1 of each sample.
        tgt_data = torch.stack([state_action[i][trace[i] - 1, :] for i in range(batch_size)])

    return tgt_data.cpu().detach().numpy()
